@@ -746,7 +746,11 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
746746 return CodeStringsMarkdown (code_strings = [])
747747
748748 imported_names : dict [str , str ] = {}
749- external_bases : list [tuple [str , str ]] = []
749+ # Use a set to deduplicate external base entries to avoid repeated expensive checks/imports.
750+ external_bases_set : set [tuple [str , str ]] = set ()
751+ # Local cache to avoid repeated _is_project_module calls for the same module_name.
752+ is_project_cache : dict [str , bool ] = {}
753+
750754 for node in ast .walk (tree ):
751755 if isinstance (node , ast .ImportFrom ) and node .module :
752756 for alias in node .names :
@@ -763,21 +767,31 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
763767
764768 if base_name and base_name in imported_names :
765769 module_name = imported_names [base_name ]
766- if not _is_project_module (module_name , project_root_path ):
767- external_bases .append ((base_name , module_name ))
768-
769- if not external_bases :
770+ # Check cache first to avoid repeated expensive checks.
771+ cached = is_project_cache .get (module_name )
772+ if cached is None :
773+ is_project = _is_project_module (module_name , project_root_path )
774+ is_project_cache [module_name ] = is_project
775+ else :
776+ is_project = cached
777+
778+ if not is_project :
779+ external_bases_set .add ((base_name , module_name ))
780+
781+ if not external_bases_set :
770782 return CodeStringsMarkdown (code_strings = [])
771783
772784 code_strings : list [CodeString ] = []
773- extracted : set [tuple [str , str ]] = set ()
774-
775- for base_name , module_name in external_bases :
776- if (module_name , base_name ) in extracted :
777- continue
785+ # Cache imported modules to avoid repeated importlib.import_module calls.
786+ imported_module_cache : dict [str , object ] = {}
778787
788+ for base_name , module_name in external_bases_set :
779789 try :
780- module = importlib .import_module (module_name )
790+ module = imported_module_cache .get (module_name )
791+ if module is None :
792+ module = importlib .import_module (module_name )
793+ imported_module_cache [module_name ] = module
794+
781795 base_class = getattr (module , base_name , None )
782796 if base_class is None :
783797 continue
@@ -799,7 +813,6 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
799813
800814 class_source = f"class { base_name } :\n " + textwrap .indent (init_source , " " )
801815 code_strings .append (CodeString (code = class_source , file_path = class_file ))
802- extracted .add ((module_name , base_name ))
803816
804817 except (ImportError , ModuleNotFoundError , AttributeError ):
805818 logger .debug (f"Failed to extract __init__ for { module_name } .{ base_name } " )
@@ -854,12 +867,13 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
854867 needed_names .add (decorator .func .value .id )
855868
856869 # Get type annotation names from class body (for dataclass fields)
857- for item in ast . walk ( class_node ) :
870+ for item in class_node . body :
858871 if isinstance (item , ast .AnnAssign ) and item .annotation :
859872 collect_names_from_annotation (item .annotation , needed_names )
860873 # Also check for field() calls which are common in dataclasses
861- if isinstance (item , ast .Call ) and isinstance (item .func , ast .Name ):
862- needed_names .add (item .func .id )
874+ elif isinstance (item , ast .Assign ) and isinstance (item .value , ast .Call ):
875+ if isinstance (item .value .func , ast .Name ):
876+ needed_names .add (item .value .func .id )
863877
864878 # Find imports that provide these names
865879 import_lines : list [str ] = []
0 commit comments