@@ -335,12 +335,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
335335 return updated_node
336336
337337
338- def extract_global_statements (source_code : str ) -> list [cst .SimpleStatementLine ]:
338+ def extract_global_statements (source_code : str ) -> tuple [ cst . Module , list [cst .SimpleStatementLine ] ]:
339339 """Extract global statements from source code."""
340340 module = cst .parse_module (source_code )
341341 collector = GlobalStatementCollector ()
342342 module .visit (collector )
343- return collector .global_statements
343+ return module , collector .global_statements
344344
345345
346346def find_last_import_line (target_code : str ) -> int :
@@ -373,39 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
373373
374374
375375def add_global_assignments (src_module_code : str , dst_module_code : str ) -> str :
376- new_added_global_statements = extract_global_statements (src_module_code )
377- existing_global_statements = extract_global_statements (dst_module_code )
376+ src_module , new_added_global_statements = extract_global_statements (src_module_code )
377+ dst_module , existing_global_statements = extract_global_statements (dst_module_code )
378378
379- # make sure we don't have any staments applited multiple times in the global level.
380- unique_global_statements = [
381- stmt
382- for stmt in new_added_global_statements
383- if not any (stmt .deep_equals (existing_stmt ) for existing_stmt in existing_global_statements )
384- ]
379+ unique_global_statements = []
380+ for stmt in new_added_global_statements :
381+ if any (
382+ stmt is existing_stmt or stmt .deep_equals (existing_stmt ) for existing_stmt in existing_global_statements
383+ ):
384+ continue
385+ unique_global_statements .append (stmt )
385386
387+ mod_dst_code = dst_module_code
388+ # Insert unique global statements if any
386389 if unique_global_statements :
387- # Find the last import line in target
388390 last_import_line = find_last_import_line (dst_module_code )
389-
390- # Parse the target code
391- target_module = cst .parse_module (dst_module_code )
392-
393- # Create transformer to insert new statements
391+ # Reuse already-parsed dst_module
394392 transformer = ImportInserter (unique_global_statements , last_import_line )
395- #
396- # # Apply transformation
397- modified_module = target_module . visit ( transformer )
398- dst_module_code = modified_module . code
399-
400- # Parse the code
401- original_module = cst . parse_module ( dst_module_code )
402- new_module = cst . parse_module ( src_module_code )
393+ # Use visit inplace, don't parse again
394+ modified_module = dst_module . visit ( transformer )
395+ mod_dst_code = modified_module . code
396+ # Parse the code after insertion
397+ original_module = cst . parse_module ( mod_dst_code )
398+ else :
399+ # No new statements to insert, reuse already-parsed dst_module
400+ original_module = dst_module
403401
402+ # Parse the src_module_code once only (already done above: src_module)
404403 # Collect assignments from the new file
405404 new_collector = GlobalAssignmentCollector ()
406- new_module .visit (new_collector )
405+ src_module .visit (new_collector )
406+ # Only create transformer if there are assignments to insert/transform
407+ if not new_collector .assignments : # nothing to transform
408+ return mod_dst_code
407409
408- # Transform the original file
410+ # Transform the original destination module
409411 transformer = GlobalAssignmentTransformer (new_collector .assignments , new_collector .assignment_order )
410412 transformed_module = original_module .visit (transformer )
411413
0 commit comments