11import ast
22import hashlib
3- from typing import Dict , Set
43
54
65class VariableNormalizer (ast .NodeTransformer ):
76 """Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
7+
88 Preserves function names, class names, parameters, built-ins, and imported names.
99 """
1010
11- def __init__ (self ):
11+ def __init__ (self ) -> None :
1212 self .var_counter = 0
13- self .var_mapping : Dict [str , str ] = {}
13+ self .var_mapping : dict [str , str ] = {}
1414 self .scope_stack = []
1515 self .builtins = set (dir (__builtins__ ))
16- self .imports : Set [str ] = set ()
17- self .global_vars : Set [str ] = set ()
18- self .nonlocal_vars : Set [str ] = set ()
19- self .parameters : Set [str ] = set () # Track function parameters
16+ self .imports : set [str ] = set ()
17+ self .global_vars : set [str ] = set ()
18+ self .nonlocal_vars : set [str ] = set ()
19+ self .parameters : set [str ] = set () # Track function parameters
2020
21- def enter_scope (self ):
22- """Enter a new scope (function/class)"""
21+ def enter_scope (self ): # noqa : ANN201
22+ """Enter a new scope (function/class). """
2323 self .scope_stack .append (
2424 {"var_mapping" : dict (self .var_mapping ), "var_counter" : self .var_counter , "parameters" : set (self .parameters )}
2525 )
2626
27- def exit_scope (self ):
28- """Exit current scope and restore parent scope"""
27+ def exit_scope (self ): # noqa : ANN201
28+ """Exit current scope and restore parent scope. """
2929 if self .scope_stack :
3030 scope = self .scope_stack .pop ()
3131 self .var_mapping = scope ["var_mapping" ]
3232 self .var_counter = scope ["var_counter" ]
3333 self .parameters = scope ["parameters" ]
3434
3535 def get_normalized_name (self , name : str ) -> str :
36- """Get or create normalized name for a variable"""
36+ """Get or create normalized name for a variable. """
3737 # Don't normalize if it's a builtin, import, global, nonlocal, or parameter
3838 if (
3939 name in self .builtins
@@ -50,34 +50,34 @@ def get_normalized_name(self, name: str) -> str:
5050 self .var_counter += 1
5151 return self .var_mapping [name ]
5252
53- def visit_Import (self , node ):
54- """Track imported names"""
53+ def visit_Import (self , node ): # noqa : ANN001, ANN201
54+ """Track imported names. """
5555 for alias in node .names :
5656 name = alias .asname if alias .asname else alias .name
5757 self .imports .add (name .split ("." )[0 ])
5858 return node
5959
60- def visit_ImportFrom (self , node ):
61- """Track imported names from modules"""
60+ def visit_ImportFrom (self , node ): # noqa : ANN001, ANN201
61+ """Track imported names from modules. """
6262 for alias in node .names :
6363 name = alias .asname if alias .asname else alias .name
6464 self .imports .add (name )
6565 return node
6666
67- def visit_Global (self , node ):
68- """Track global variable declarations"""
67+ def visit_Global (self , node ): # noqa : ANN001, ANN201
68+ """Track global variable declarations. """
6969 # Avoid repeated .add calls by using set.update with list
7070 self .global_vars .update (node .names )
7171 return node
7272
73- def visit_Nonlocal (self , node ):
74- """Track nonlocal variable declarations"""
75- for name in node . names :
76- self .nonlocal_vars .add ( name )
73+ def visit_Nonlocal (self , node ): # noqa : ANN001, ANN201
74+ """Track nonlocal variable declarations. """
75+ # Using set.update for batch insertion (faster than add-in-loop)
76+ self .nonlocal_vars .update ( node . names )
7777 return node
7878
79- def visit_FunctionDef (self , node ):
80- """Process function but keep function name and parameters unchanged"""
79+ def visit_FunctionDef (self , node ): # noqa : ANN001, ANN201
80+ """Process function but keep function name and parameters unchanged. """
8181 self .enter_scope ()
8282
8383 # Track all parameters (don't modify them)
@@ -95,19 +95,19 @@ def visit_FunctionDef(self, node):
9595 self .exit_scope ()
9696 return node
9797
98- def visit_AsyncFunctionDef (self , node ):
99- """Handle async functions same as regular functions"""
98+ def visit_AsyncFunctionDef (self , node ): # noqa : ANN001, ANN201
99+ """Handle async functions same as regular functions. """
100100 return self .visit_FunctionDef (node )
101101
102- def visit_ClassDef (self , node ):
103- """Process class but keep class name unchanged"""
102+ def visit_ClassDef (self , node ): # noqa : ANN001, ANN201
103+ """Process class but keep class name unchanged. """
104104 self .enter_scope ()
105105 node = self .generic_visit (node )
106106 self .exit_scope ()
107107 return node
108108
109- def visit_Name (self , node ):
110- """Normalize variable names in Name nodes"""
109+ def visit_Name (self , node ): # noqa : ANN001, ANN201
110+ """Normalize variable names in Name nodes. """
111111 if isinstance (node .ctx , (ast .Store , ast .Del )):
112112 # For assignments and deletions, check if we should normalize
113113 if (
@@ -118,20 +118,20 @@ def visit_Name(self, node):
118118 and node .id not in self .nonlocal_vars
119119 ):
120120 node .id = self .get_normalized_name (node .id )
121- elif isinstance (node .ctx , ast .Load ):
121+ elif isinstance (node .ctx , ast .Load ): # noqa : SIM102
122122 # For loading, use existing mapping if available
123123 if node .id in self .var_mapping :
124124 node .id = self .var_mapping [node .id ]
125125 return node
126126
127- def visit_ExceptHandler (self , node ):
128- """Normalize exception variable names"""
127+ def visit_ExceptHandler (self , node ): # noqa : ANN001, ANN201
128+ """Normalize exception variable names. """
129129 if node .name :
130130 node .name = self .get_normalized_name (node .name )
131131 return self .generic_visit (node )
132132
133- def visit_comprehension (self , node ):
134- """Normalize comprehension target variables"""
133+ def visit_comprehension (self , node ): # noqa : ANN001, ANN201
134+ """Normalize comprehension target variables. """
135135 # Create new scope for comprehension
136136 old_mapping = dict (self .var_mapping )
137137 old_counter = self .var_counter
@@ -144,23 +144,25 @@ def visit_comprehension(self, node):
144144 self .var_counter = old_counter
145145 return node
146146
147- def visit_For (self , node ):
148- """Handle for loop target variables"""
147+ def visit_For (self , node ): # noqa : ANN001, ANN201
148+ """Handle for loop target variables. """
149149 # The target in a for loop is a local variable that should be normalized
150150 return self .generic_visit (node )
151151
152- def visit_With (self , node ):
153- """Handle with statement as variables"""
152+ def visit_With (self , node ): # noqa : ANN001, ANN201
153+ """Handle with statement as variables. """
154154 return self .generic_visit (node )
155155
156156
157- def normalize_code (code : str , remove_docstrings : bool = True , return_ast_dump : bool = False ) -> str :
157+ def normalize_code (code : str , remove_docstrings : bool = True , return_ast_dump : bool = False ) -> str : # noqa : FBT002, FBT001
158158 """Normalize Python code by parsing, cleaning, and normalizing only variable names.
159+
159160 Function names, class names, and parameters are preserved.
160161
161162 Args:
162163 code: Python source code as string
163164 remove_docstrings: Whether to remove docstrings
165+ return_ast_dump: return_ast_dump
164166
165167 Returns:
166168 Normalized code as string
@@ -191,7 +193,7 @@ def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: b
191193 raise ValueError (msg ) from e
192194
193195
194- def remove_docstrings_from_ast (node ):
196+ def remove_docstrings_from_ast (node ): # noqa : ANN001, ANN201
195197 """Remove docstrings from AST nodes."""
196198 # Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
197199 node_types = (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef , ast .Module )
@@ -242,6 +244,7 @@ def are_codes_duplicate(code1: str, code2: str) -> bool:
242244 try :
243245 normalized1 = normalize_code (code1 , return_ast_dump = True )
244246 normalized2 = normalize_code (code2 , return_ast_dump = True )
245- return normalized1 == normalized2
246247 except Exception :
247248 return False
249+ else :
250+ return normalized1 == normalized2
0 commit comments