Skip to content

Commit 565d65b

Browse files
authored
Merge pull request #776 from codeflash-ai/libcst-importstar-bug
don't iterate over star imports
2 parents b8e01f3 + 6d678aa commit 565d65b

File tree

2 files changed

+228
-3
lines changed

2 files changed

+228
-3
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272272
if child.module is None:
273273
continue
274274
module = self.get_full_dotted_name(child.module)
275+
if isinstance(child.names, cst.ImportStar):
276+
continue
275277
for alias in child.names:
276278
if isinstance(alias, cst.ImportAlias):
277279
name = alias.name.value
@@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
414416
return transformed_module.code
415417

416418

419+
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
420+
try:
421+
module_path = module_name.replace(".", "/")
422+
possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
423+
424+
module_file = None
425+
for path in possible_paths:
426+
if path.exists():
427+
module_file = path
428+
break
429+
430+
if module_file is None:
431+
logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
432+
return set()
433+
434+
with module_file.open(encoding="utf8") as f:
435+
module_code = f.read()
436+
437+
tree = ast.parse(module_code)
438+
439+
all_names = None
440+
for node in ast.walk(tree):
441+
if (
442+
isinstance(node, ast.Assign)
443+
and len(node.targets) == 1
444+
and isinstance(node.targets[0], ast.Name)
445+
and node.targets[0].id == "__all__"
446+
):
447+
if isinstance(node.value, (ast.List, ast.Tuple)):
448+
all_names = []
449+
for elt in node.value.elts:
450+
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
451+
all_names.append(elt.value)
452+
elif isinstance(elt, ast.Str): # Python < 3.8 compatibility
453+
all_names.append(elt.s)
454+
break
455+
456+
if all_names is not None:
457+
return set(all_names)
458+
459+
public_names = set()
460+
for node in tree.body:
461+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
462+
if not node.name.startswith("_"):
463+
public_names.add(node.name)
464+
elif isinstance(node, ast.Assign):
465+
for target in node.targets:
466+
if isinstance(target, ast.Name) and not target.id.startswith("_"):
467+
public_names.add(target.id)
468+
elif isinstance(node, ast.AnnAssign):
469+
if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
470+
public_names.add(node.target.id)
471+
elif isinstance(node, ast.Import) or (
472+
isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
473+
):
474+
for alias in node.names:
475+
name = alias.asname or alias.name
476+
if not name.startswith("_"):
477+
public_names.add(name)
478+
479+
return public_names # noqa: TRY300
480+
481+
except Exception as e:
482+
logger.warning(f"Error resolving star import for {module_name}: {e}")
483+
return set()
484+
485+
417486
def add_needed_imports_from_module(
418487
src_module_code: str,
419488
dst_module_code: str,
@@ -468,9 +537,23 @@ def add_needed_imports_from_module(
468537
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
469538
):
470539
continue # Skip adding imports for helper functions already in the context
471-
if f"{mod}.{obj}" not in dotted_import_collector.imports:
472-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
473-
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
540+
541+
# Handle star imports by resolving them to actual symbol names
542+
if obj == "*":
543+
resolved_symbols = resolve_star_import(mod, project_root)
544+
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
545+
546+
for symbol in resolved_symbols:
547+
if (
548+
f"{mod}.{symbol}" not in helper_functions_fqn
549+
and f"{mod}.{symbol}" not in dotted_import_collector.imports
550+
):
551+
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
552+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
553+
else:
554+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
555+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
556+
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
474557
except Exception as e:
475558
logger.exception(f"Error adding imports to destination module code: {e}")
476559
return dst_module_code

tests/test_add_needed_imports_from_module.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
44
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
55

6+
import tempfile
7+
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
8+
import libcst as cst
9+
from codeflash.models.models import FunctionParent
610

711
def test_add_needed_imports_from_module0() -> None:
812
src_module = '''import ast
@@ -349,3 +353,141 @@ def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[st
349353
project_root_path=Path(__file__).resolve().parent.resolve(),
350354
)
351355
assert new_code == expected
356+
357+
358+
359+
360+
def test_resolve_star_import_with_all_defined():
361+
"""Test resolve_star_import when __all__ is explicitly defined."""
362+
with tempfile.TemporaryDirectory() as tmpdir:
363+
project_root = Path(tmpdir)
364+
test_module = project_root / 'test_module.py'
365+
366+
# Create a test module with __all__ definition
367+
test_module.write_text('''
368+
__all__ = ['public_function', 'PublicClass']
369+
370+
def public_function():
371+
pass
372+
373+
def _private_function():
374+
pass
375+
376+
class PublicClass:
377+
pass
378+
379+
class AnotherPublicClass:
380+
"""Not in __all__ so should be excluded."""
381+
pass
382+
''')
383+
384+
symbols = resolve_star_import('test_module', project_root)
385+
expected_symbols = {'public_function', 'PublicClass'}
386+
assert symbols == expected_symbols
387+
388+
389+
def test_resolve_star_import_without_all_defined():
390+
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
391+
with tempfile.TemporaryDirectory() as tmpdir:
392+
project_root = Path(tmpdir)
393+
test_module = project_root / 'test_module.py'
394+
395+
# Create a test module without __all__ definition
396+
test_module.write_text('''
397+
def public_func():
398+
pass
399+
400+
def _private_func():
401+
pass
402+
403+
class PublicClass:
404+
pass
405+
406+
PUBLIC_VAR = 42
407+
_private_var = 'secret'
408+
''')
409+
410+
symbols = resolve_star_import('test_module', project_root)
411+
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
412+
assert symbols == expected_symbols
413+
414+
415+
def test_resolve_star_import_nonexistent_module():
416+
"""Test resolve_star_import with non-existent module - should return empty set."""
417+
with tempfile.TemporaryDirectory() as tmpdir:
418+
project_root = Path(tmpdir)
419+
420+
symbols = resolve_star_import('nonexistent_module', project_root)
421+
assert symbols == set()
422+
423+
424+
def test_dotted_import_collector_skips_star_imports():
425+
"""Test that DottedImportCollector correctly skips star imports."""
426+
code_with_star_import = '''
427+
from typing import *
428+
from pathlib import Path
429+
from collections import defaultdict
430+
import os
431+
'''
432+
433+
module = cst.parse_module(code_with_star_import)
434+
collector = DottedImportCollector()
435+
module.visit(collector)
436+
437+
# Should collect regular imports but skip the star import
438+
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
439+
assert collector.imports == expected_imports
440+
441+
442+
def test_add_needed_imports_with_star_import_resolution():
443+
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
444+
with tempfile.TemporaryDirectory() as tmpdir:
445+
project_root = Path(tmpdir)
446+
447+
# Create a source module that exports symbols
448+
src_module = project_root / 'source_module.py'
449+
src_module.write_text('''
450+
__all__ = ['UtilFunction', 'HelperClass']
451+
452+
def UtilFunction():
453+
pass
454+
455+
class HelperClass:
456+
pass
457+
''')
458+
459+
# Create source code that uses star import
460+
src_code = '''
461+
from source_module import *
462+
463+
def my_function():
464+
helper = HelperClass()
465+
UtilFunction()
466+
return helper
467+
'''
468+
469+
# Destination code that needs the imports resolved
470+
dst_code = '''
471+
def my_function():
472+
helper = HelperClass()
473+
UtilFunction()
474+
return helper
475+
'''
476+
477+
src_path = project_root / 'src.py'
478+
dst_path = project_root / 'dst.py'
479+
src_path.write_text(src_code)
480+
481+
result = add_needed_imports_from_module(
482+
src_code, dst_code, src_path, dst_path, project_root
483+
)
484+
485+
# The result should have individual imports instead of star import
486+
expected_result = '''from source_module import HelperClass, UtilFunction
487+
488+
def my_function():
489+
helper = HelperClass()
490+
UtilFunction()
491+
return helper
492+
'''
493+
assert result == expected_result

0 commit comments

Comments
 (0)