@@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
272272 if child .module is None :
273273 continue
274274 module = self .get_full_dotted_name (child .module )
275+ if isinstance (child .names , cst .ImportStar ):
276+ continue
275277 for alias in child .names :
276278 if isinstance (alias , cst .ImportAlias ):
277279 name = alias .name .value
@@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
414416 return transformed_module .code
415417
416418
419+ def resolve_star_import (module_name : str , project_root : Path ) -> set [str ]:
420+ try :
421+ module_path = module_name .replace ("." , "/" )
422+ possible_paths = [project_root / f"{ module_path } .py" , project_root / f"{ module_path } /__init__.py" ]
423+
424+ module_file = None
425+ for path in possible_paths :
426+ if path .exists ():
427+ module_file = path
428+ break
429+
430+ if module_file is None :
431+ logger .warning (f"Could not find module file for { module_name } , skipping star import resolution" )
432+ return set ()
433+
434+ with module_file .open (encoding = "utf8" ) as f :
435+ module_code = f .read ()
436+
437+ tree = ast .parse (module_code )
438+
439+ all_names = None
440+ for node in ast .walk (tree ):
441+ if (
442+ isinstance (node , ast .Assign )
443+ and len (node .targets ) == 1
444+ and isinstance (node .targets [0 ], ast .Name )
445+ and node .targets [0 ].id == "__all__"
446+ ):
447+ if isinstance (node .value , (ast .List , ast .Tuple )):
448+ all_names = []
449+ for elt in node .value .elts :
450+ if isinstance (elt , ast .Constant ) and isinstance (elt .value , str ):
451+ all_names .append (elt .value )
452+ elif isinstance (elt , ast .Str ): # Python < 3.8 compatibility
453+ all_names .append (elt .s )
454+ break
455+
456+ if all_names is not None :
457+ return set (all_names )
458+
459+ public_names = set ()
460+ for node in tree .body :
461+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
462+ if not node .name .startswith ("_" ):
463+ public_names .add (node .name )
464+ elif isinstance (node , ast .Assign ):
465+ for target in node .targets :
466+ if isinstance (target , ast .Name ) and not target .id .startswith ("_" ):
467+ public_names .add (target .id )
468+ elif isinstance (node , ast .AnnAssign ):
469+ if isinstance (node .target , ast .Name ) and not node .target .id .startswith ("_" ):
470+ public_names .add (node .target .id )
471+ elif isinstance (node , ast .Import ) or (
472+ isinstance (node , ast .ImportFrom ) and not any (alias .name == "*" for alias in node .names )
473+ ):
474+ for alias in node .names :
475+ name = alias .asname or alias .name
476+ if not name .startswith ("_" ):
477+ public_names .add (name )
478+
479+ return public_names # noqa: TRY300
480+
481+ except Exception as e :
482+ logger .warning (f"Error resolving star import for { module_name } : { e } " )
483+ return set ()
484+
485+
417486def add_needed_imports_from_module (
418487 src_module_code : str ,
419488 dst_module_code : str ,
@@ -468,9 +537,23 @@ def add_needed_imports_from_module(
468537 f"{ mod } .{ obj } " in helper_functions_fqn or dst_context .full_module_name == mod # avoid circular deps
469538 ):
470539 continue # Skip adding imports for helper functions already in the context
471- if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
472- AddImportsVisitor .add_needed_import (dst_context , mod , obj )
473- RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
540+
541+ # Handle star imports by resolving them to actual symbol names
542+ if obj == "*" :
543+ resolved_symbols = resolve_star_import (mod , project_root )
544+ logger .debug (f"Resolved star import from { mod } : { resolved_symbols } " )
545+
546+ for symbol in resolved_symbols :
547+ if (
548+ f"{ mod } .{ symbol } " not in helper_functions_fqn
549+ and f"{ mod } .{ symbol } " not in dotted_import_collector .imports
550+ ):
551+ AddImportsVisitor .add_needed_import (dst_context , mod , symbol )
552+ RemoveImportsVisitor .remove_unused_import (dst_context , mod , symbol )
553+ else :
554+ if f"{ mod } .{ obj } " not in dotted_import_collector .imports :
555+ AddImportsVisitor .add_needed_import (dst_context , mod , obj )
556+ RemoveImportsVisitor .remove_unused_import (dst_context , mod , obj )
474557 except Exception as e :
475558 logger .exception (f"Error adding imports to destination module code: { e } " )
476559 return dst_module_code
0 commit comments