Skip to content

Commit 1c0d8da

Browse files
authored
Merge pull request #1339 from codeflash-ai/coverage-no-files
Skip when no gen tests and no existing tests
2 parents dfe073a + daf570b commit 1c0d8da

File tree

9 files changed

+75
-41
lines changed

9 files changed

+75
-41
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,9 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
14291429
numerical_names: set[str] = set()
14301430
modules_used: set[str] = set()
14311431

1432-
for node in ast.walk(tree):
1432+
stack: list[ast.AST] = [tree]
1433+
while stack:
1434+
node = stack.pop()
14331435
if isinstance(node, ast.Import):
14341436
for alias in node.names:
14351437
# import numpy or import numpy as np
@@ -1451,6 +1453,8 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
14511453
name = alias.asname if alias.asname else alias.name
14521454
numerical_names.add(name)
14531455
modules_used.add(module_root)
1456+
else:
1457+
stack.extend(ast.iter_child_nodes(node))
14541458

14551459
return numerical_names, modules_used
14561460

codeflash/context/code_context_extractor.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,11 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
746746
return CodeStringsMarkdown(code_strings=[])
747747

748748
imported_names: dict[str, str] = {}
749-
external_bases: list[tuple[str, str]] = []
749+
# Use a set to deduplicate external base entries to avoid repeated expensive checks/imports.
750+
external_bases_set: set[tuple[str, str]] = set()
751+
# Local cache to avoid repeated _is_project_module calls for the same module_name.
752+
is_project_cache: dict[str, bool] = {}
753+
750754
for node in ast.walk(tree):
751755
if isinstance(node, ast.ImportFrom) and node.module:
752756
for alias in node.names:
@@ -763,21 +767,31 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
763767

764768
if base_name and base_name in imported_names:
765769
module_name = imported_names[base_name]
766-
if not _is_project_module(module_name, project_root_path):
767-
external_bases.append((base_name, module_name))
768-
769-
if not external_bases:
770+
# Check cache first to avoid repeated expensive checks.
771+
cached = is_project_cache.get(module_name)
772+
if cached is None:
773+
is_project = _is_project_module(module_name, project_root_path)
774+
is_project_cache[module_name] = is_project
775+
else:
776+
is_project = cached
777+
778+
if not is_project:
779+
external_bases_set.add((base_name, module_name))
780+
781+
if not external_bases_set:
770782
return CodeStringsMarkdown(code_strings=[])
771783

772784
code_strings: list[CodeString] = []
773-
extracted: set[tuple[str, str]] = set()
774-
775-
for base_name, module_name in external_bases:
776-
if (module_name, base_name) in extracted:
777-
continue
785+
# Cache imported modules to avoid repeated importlib.import_module calls.
786+
imported_module_cache: dict[str, object] = {}
778787

788+
for base_name, module_name in external_bases_set:
779789
try:
780-
module = importlib.import_module(module_name)
790+
module = imported_module_cache.get(module_name)
791+
if module is None:
792+
module = importlib.import_module(module_name)
793+
imported_module_cache[module_name] = module
794+
781795
base_class = getattr(module, base_name, None)
782796
if base_class is None:
783797
continue
@@ -799,7 +813,6 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
799813

800814
class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ")
801815
code_strings.append(CodeString(code=class_source, file_path=class_file))
802-
extracted.add((module_name, base_name))
803816

804817
except (ImportError, ModuleNotFoundError, AttributeError):
805818
logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
@@ -854,12 +867,13 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
854867
needed_names.add(decorator.func.value.id)
855868

856869
# Get type annotation names from class body (for dataclass fields)
857-
for item in ast.walk(class_node):
870+
for item in class_node.body:
858871
if isinstance(item, ast.AnnAssign) and item.annotation:
859872
collect_names_from_annotation(item.annotation, needed_names)
860873
# Also check for field() calls which are common in dataclasses
861-
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
862-
needed_names.add(item.func.id)
874+
elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
875+
if isinstance(item.value.func, ast.Name):
876+
needed_names.add(item.value.func.id)
863877

864878
# Find imports that provide these names
865879
import_lines: list[str] = []

codeflash/discovery/discover_unit_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def discover_unit_tests(
656656

657657
# Existing Python logic
658658
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
659-
strategy = framework_strategies.get(cfg.test_framework, None)
659+
strategy = framework_strategies.get(cfg.test_framework)
660660
if not strategy:
661661
error_message = f"Unsupported test framework: {cfg.test_framework}"
662662
raise ValueError(error_message)

codeflash/github/PrComment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ class PrComment:
2525
best_async_throughput: Optional[int] = None
2626

2727
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
28-
report_table = {
29-
test_type.to_name(): result
30-
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
31-
if test_type.to_name()
32-
}
28+
report_table: dict[str, dict[str, int]] = {}
29+
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
30+
name = test_type.to_name()
31+
if name:
32+
report_table[name] = result
3333

3434
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
3535
"optimization_explanation": self.optimization_explanation,

codeflash/languages/javascript/find_references.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -651,15 +651,18 @@ def _find_reexports(
651651
652652
"""
653653
references: list[Reference] = []
654+
export_name = exported.export_name or exported.function_name
655+
656+
# Skip expensive parsing if export name not in source
657+
if export_name not in source_code:
658+
return references
659+
654660
exports = analyzer.find_exports(source_code)
655661
lines = source_code.splitlines()
656662

657663
for exp in exports:
658664
if not exp.is_reexport:
659665
continue
660-
661-
# Check if this re-exports our function
662-
export_name = exported.export_name or exported.function_name
663666
for name, alias in exp.exported_names:
664667
if name == export_name:
665668
# This is a re-export of our function

codeflash/models/models.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
if TYPE_CHECKING:
1717
from collections.abc import Iterator
18+
1819
import enum
1920
import re
2021
import sys
@@ -876,15 +877,14 @@ def number_of_loops(self) -> int:
876877
return max(test_result.loop_index for test_result in self.test_results)
877878

878879
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
879-
report = {}
880-
for test_type in TestType:
881-
report[test_type] = {"passed": 0, "failed": 0}
880+
report: dict[TestType, dict[str, int]] = {tt: {"passed": 0, "failed": 0} for tt in TestType}
882881
for test_result in self.test_results:
883-
if test_result.loop_index == 1:
884-
if test_result.did_pass:
885-
report[test_result.test_type]["passed"] += 1
886-
else:
887-
report[test_result.test_type]["failed"] += 1
882+
if test_result.loop_index != 1:
883+
continue
884+
if test_result.did_pass:
885+
report[test_result.test_type]["passed"] += 1
886+
else:
887+
report[test_result.test_type]["failed"] += 1
888888
return report
889889

890890
@staticmethod

codeflash/models/test_type.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ class TestType(Enum):
1212
def to_name(self) -> str:
1313
if self is TestType.INIT_STATE_TEST:
1414
return ""
15-
names = {
16-
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
17-
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
18-
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
19-
TestType.REPLAY_TEST: "⏪ Replay Tests",
20-
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
21-
}
22-
return names[self]
15+
return _TO_NAME_MAP[self]
16+
17+
18+
_TO_NAME_MAP: dict[TestType, str] = {
19+
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
20+
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
21+
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
22+
TestType.REPLAY_TEST: "⏪ Replay Tests",
23+
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
24+
}

codeflash/optimization/function_optimizer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,16 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P
496496
should_run_experiment = self.experiment_id is not None
497497
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
498498
ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})
499+
500+
# Early check: if --no-gen-tests is set, verify there are existing tests for this function
501+
if self.args.no_gen_tests:
502+
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
503+
if not self.function_to_tests.get(func_qualname):
504+
return Failure(
505+
f"No existing tests found for '{self.function_to_optimize.function_name}'. "
506+
f"Cannot optimize without tests when --no-gen-tests is set."
507+
)
508+
499509
self.cleanup_leftover_test_return_values()
500510
file_name_from_test_module_name.cache_clear()
501511
ctx_result = self.get_code_optimization_context()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ ignore = [
272272
"ANN401", # typing.Any disallowed
273273
"ARG001", # Unused function argument (common in abstract/interface methods)
274274
"TRY300", # Consider moving to else block
275+
"FURB110", # if-exp-instead-of-or-operator - we prefer explicit if-else over "or"
275276
"TRY401", # Redundant exception in logging.exception
276277
"PLR0911", # Too many return statements
277278
"PLW0603", # Global statement

0 commit comments

Comments
 (0)