55from dataclasses import dataclass , field
66from itertools import chain
77from pathlib import Path
8- from typing import TYPE_CHECKING , Optional
8+ from typing import TYPE_CHECKING , Optional , Union
99
1010import libcst as cst
1111
@@ -52,46 +52,64 @@ def collect_top_level_definitions(
5252 node : cst .CSTNode , definitions : Optional [dict [str , UsageInfo ]] = None
5353) -> dict [str , UsageInfo ]:
5454 """Recursively collect all top-level variable, function, and class definitions."""
55+ # Locally bind types and helpers for faster lookup
56+ FunctionDef = cst .FunctionDef # noqa: N806
57+ ClassDef = cst .ClassDef # noqa: N806
58+ Assign = cst .Assign # noqa: N806
59+ AnnAssign = cst .AnnAssign # noqa: N806
60+ AugAssign = cst .AugAssign # noqa: N806
61+ IndentedBlock = cst .IndentedBlock # noqa: N806
62+
5563 if definitions is None :
5664 definitions = {}
5765
58- # Handle top-level function definitions
59- if isinstance (node , cst .FunctionDef ):
66+ # Speed: one type() lookup bound to a local instead of several isinstance calls (exact-type match; subclasses are not matched)
67+ node_type = type (node )
68+ # Fast path: function def
69+ if node_type is FunctionDef :
6070 name = node .name .value
6171 definitions [name ] = UsageInfo (
6272 name = name ,
6373 used_by_qualified_function = False , # Will be marked later if in qualified functions
6474 )
6575 return definitions
6676
67- # Handle top-level class definitions
68- if isinstance ( node , cst . ClassDef ) :
77+ # Fast path: class def
78+ if node_type is ClassDef :
6979 name = node .name .value
7080 definitions [name ] = UsageInfo (name = name )
7181
72- # Also collect method definitions within the class
73- if hasattr (node , "body" ) and isinstance (node .body , cst .IndentedBlock ):
74- for statement in node .body .body :
75- if isinstance (statement , cst .FunctionDef ):
76- method_name = f"{ name } .{ statement .name .value } "
82+ # Collect class methods
83+ body = getattr (node , "body" , None )
84+ if body is not None and type (body ) is IndentedBlock :
85+ statements = body .body
86+ # Precompute the "<class name>." prefix once instead of formatting it for every method
87+ prefix = name + "."
88+ for statement in statements :
89+ if type (statement ) is FunctionDef :
90+ method_name = prefix + statement .name .value
7791 definitions [method_name ] = UsageInfo (name = method_name )
7892
7993 return definitions
8094
81- # Handle top-level variable assignments
82- if isinstance (node , cst .Assign ):
83- for target in node .targets :
95+ # Fast path: assignment
96+ if node_type is Assign :
97+ # Bind the targets and the dict setter to locals; names are still resolved by extract_names_from_targets
98+ targets = node .targets
99+ append_def = definitions .__setitem__
100+ for target in targets :
84101 names = extract_names_from_targets (target .target )
85102 for name in names :
86- definitions [ name ] = UsageInfo (name = name )
103+ append_def ( name , UsageInfo (name = name ) )
87104 return definitions
88105
89- if isinstance (node , (cst .AnnAssign , cst .AugAssign )):
90- if isinstance (node .target , cst .Name ):
91- name = node .target .value
106+ if node_type is AnnAssign or node_type is AugAssign :
107+ tgt = node .target
108+ if type (tgt ) is cst .Name :
109+ name = tgt .value
92110 definitions [name ] = UsageInfo (name = name )
93111 else :
94- names = extract_names_from_targets (node . target )
112+ names = extract_names_from_targets (tgt )
95113 for name in names :
96114 definitions [name ] = UsageInfo (name = name )
97115 return definitions
@@ -100,12 +118,15 @@ def collect_top_level_definitions(
100118 section_names = get_section_names (node )
101119
102120 if section_names :
121+ getattr_ = getattr
103122 for section in section_names :
104- original_content = getattr (node , section , None )
123+ original_content = getattr_ (node , section , None )
124+ # Sections are dispatched on an explicit list/tuple isinstance check below (no duck typing)
105125 # If section contains a list of nodes
106126 if isinstance (original_content , (list , tuple )):
127+ defs = definitions # Move out for minor speed
107128 for child in original_content :
108- collect_top_level_definitions (child , definitions )
129+ collect_top_level_definitions (child , defs )
109130 # If section contains a single node
110131 elif original_content is not None :
111132 collect_top_level_definitions (original_content , definitions )
@@ -122,6 +143,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
122143class DependencyCollector (cst .CSTVisitor ):
123144 """Collects dependencies between definitions using the visitor pattern with depth tracking."""
124145
146+ METADATA_DEPENDENCIES = (cst .metadata .ParentNodeProvider ,)
147+
125148 def __init__ (self , definitions : dict [str , UsageInfo ]) -> None :
126149 super ().__init__ ()
127150 self .definitions = definitions
@@ -259,8 +282,12 @@ def visit_Name(self, node: cst.Name) -> None:
259282 if self .processing_variable and name in self .current_variable_names :
260283 return
261284
262- # Check if name is a top-level definition we're tracking
263285 if name in self .definitions and name != self .current_top_level_name :
286+ # Skip if we are referencing a class attribute and not a top-level definition
287+ if self .class_depth > 0 :
288+ parent = self .get_metadata (cst .metadata .ParentNodeProvider , node )
289+ if parent is not None and isinstance (parent , cst .Attribute ):
290+ return
264291 self .definitions [self .current_top_level_name ].dependencies .add (name )
265292
266293
@@ -293,13 +320,20 @@ def _expand_qualified_functions(self) -> set[str]:
293320
294321 def mark_used_definitions (self ) -> None :
295322 """Find all qualified functions and mark them and their dependencies as used."""
296- # First identify all specified functions (including expanded ones)
297- functions_to_mark = [name for name in self .expanded_qualified_functions if name in self .definitions ]
323+ # Prefer a set intersection over a list comprehension when the inputs allow it
324+ expanded_names = self .expanded_qualified_functions
325+ defs = self .definitions
326+ # Use set intersection only when expanded_names is a set (dict key views support & with sets); otherwise filter with a list comprehension
327+ fnames = (
328+ expanded_names & defs .keys ()
329+ if isinstance (expanded_names , set )
330+ else [name for name in expanded_names if name in defs ]
331+ )
298332
299333 # For each specified function, mark it and all its dependencies as used
300- for func_name in functions_to_mark :
301- self . definitions [func_name ].used_by_qualified_function = True
302- for dep in self . definitions [func_name ].dependencies :
334+ for func_name in fnames :
335+ defs [func_name ].used_by_qualified_function = True
336+ for dep in defs [func_name ].dependencies :
303337 self .mark_as_used_recursively (dep )
304338
305339 def mark_as_used_recursively (self , name : str ) -> None :
@@ -457,6 +491,25 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
457491 return node , False
458492
459493
494+ def collect_top_level_defs_with_usages (
495+ code : Union [str , cst .Module ], qualified_function_names : set [str ]
496+ ) -> dict [str , UsageInfo ]:
497+ """Collect all top level definitions (classes, variables or functions) and their usages."""
498+ module = code if isinstance (code , cst .Module ) else cst .parse_module (code )
499+ # Collect all definitions (top-level classes, variables or functions)
500+ definitions = collect_top_level_definitions (module )
501+
502+ # Collect dependencies between definitions using the visitor pattern
503+ wrapper = cst .MetadataWrapper (module )
504+ dependency_collector = DependencyCollector (definitions )
505+ wrapper .visit (dependency_collector )
506+
507+ # Mark definitions used by specified functions, and their dependencies recursively
508+ usage_marker = QualifiedFunctionUsageMarker (definitions , qualified_function_names )
509+ usage_marker .mark_used_definitions ()
510+ return definitions
511+
512+
460513def remove_unused_definitions_by_function_names (code : str , qualified_function_names : set [str ]) -> str :
461514 """Analyze a file and remove top level definitions not used by specified functions.
462515
@@ -476,19 +529,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
476529 return code
477530
478531 try :
479- # Collect all definitions (top level classes, variables or function)
480- definitions = collect_top_level_definitions (module )
481-
482- # Collect dependencies between definitions using the visitor pattern
483- dependency_collector = DependencyCollector (definitions )
484- module .visit (dependency_collector )
485-
486- # Mark definitions used by specified functions, and their dependencies recursively
487- usage_marker = QualifiedFunctionUsageMarker (definitions , qualified_function_names )
488- usage_marker .mark_used_definitions ()
532+ defs_with_usages = collect_top_level_defs_with_usages (module , qualified_function_names )
489533
490534 # Apply the recursive removal transformation
491- modified_module , _ = remove_unused_definitions_recursively (module , definitions )
535+ modified_module , _ = remove_unused_definitions_recursively (module , defs_with_usages )
492536
493537 return modified_module .code if modified_module else "" # noqa: TRY300
494538 except Exception as e :
0 commit comments